import pymc as pm
import numpy as np
import pandas as pd
import arviz as az
import pytensor.tensor as pt
# Load the simulated Aalen survival data and keep only the columns used below.
df = pd.read_csv("aalen_simdata.csv")
df = df[['subject', 'x', 'dose', 'M', 'start', 'stop', 'event']]
df.head()

|   | subject | x | dose | M | start | stop | event |
|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | ctrl | 6.74 | 0 | 4.00 | 0 |
| 1 | 1 | 0 | ctrl | 6.91 | 4 | 8.00 | 0 |
| 2 | 1 | 0 | ctrl | 6.90 | 8 | 12.00 | 0 |
| 3 | 1 | 0 | ctrl | 6.71 | 12 | 26.00 | 0 |
| 4 | 1 | 0 | ctrl | 6.45 | 26 | 46.85 | 1 |
df.groupby(['x', 'dose'])[['event', 'M']].agg(['mean', 'sum'])

|   |   | event |   | M |   |
|---|---|---|---|---|---|
| mean | sum | mean | sum | ||
| x | dose | ||||
| 0 | ctrl | 0.164179 | 66 | 6.996915 | 2812.76 |
| 1 | high | 0.119205 | 54 | 8.081589 | 3660.96 |
| low | 0.139037 | 52 | 7.302620 | 2731.18 | |
Code
import matplotlib.pyplot as plt
import pandas as pd

# Derive subject-level info for ordering: one row per subject carrying the
# treatment arm and the last follow-up time, sorted so timelines cluster by arm.
subject_info = (
    df.groupby('subject')
    .agg(
        x=('x', 'first'),
        max_stop=('stop', 'max')
    )
    .sort_values(['x', 'max_stop'])
)
subjects = subject_info.index.tolist()
# Map each subject id to a y-axis row position.
subject_to_y = {s: i for i, s in enumerate(subjects)}

fig, ax = plt.subplots(figsize=(8, 0.1 * len(subjects)))
# Draw one horizontal segment per (start, stop] risk interval;
# a red dot at the interval end marks an observed event.
for _, row in df.iterrows():
    y = subject_to_y[row['subject']]
    color = 'tab:blue' if row['x'] == 1 else 'tab:orange'
    ax.hlines(
        y=y,
        xmin=row['start'],
        xmax=row['stop'],
        color=color,
        linewidth=3
    )
    if row['event'] == 1:
        ax.plot(
            row['stop'],
            y,
            marker='o',
            color='red',
            markersize=6,
            zorder=3
        )
# Axis formatting
ax.set_yticks(range(len(subjects)))
ax.set_yticklabels(subjects)
ax.set_xlabel("Time")
ax.set_ylabel("Subject")
# Visual separation between treatment groups
x0_count = (subject_info['x'] == 0).sum()
ax.axhline(x0_count - 0.5, color='black', linestyle='--', linewidth=1)
# Legend
from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], color='tab:blue', lw=3, label='x = 1'),
    Line2D([0], [0], color='tab:orange', lw=3, label='x = 0'),
    Line2D([0], [0], marker='o', color='red', lw=0, label='Event', markersize=6)
]
ax.legend(handles=legend_elements, loc='upper right')
ax.set_title("Subject Timelines Ordered by Treatment Level")
plt.tight_layout()
plt.show()

Data Preparation
def prepare_aalen_dpa_data(
    df,
    subject_col="subject",
    start_col="start",
    stop_col="stop",
    event_col="event",
    x_col="x",
    m_col="M",
    dose_col="dose",
):
    """
    Prepare Andersen–Gill / Aalen dynamic path data for PyMC.

    All returned arrays share one row order — the long-format data sorted by
    (subject, interval start) — so they can safely share an 'obs' dimension
    in the model.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format start–stop survival data
    subject_col : str
        Subject identifier
    start_col, stop_col : str
        Interval boundaries
    event_col : str
        Event indicator (0/1)
    x_col : str
        Exposure / treatment
    m_col : str
        Mediator measured at interval start
    dose_col : str
        Dose-group column with levels "low" / "high" (anything else, e.g.
        "ctrl", is the reference level)

    Returns
    -------
    dict
        Dictionary of numpy arrays ready for PyMC

    Raises
    ------
    ValueError
        If any interval has non-positive length.
    """
    df = df.copy()
    # -------------------------------------------------
    # 0. Canonical row order
    # -------------------------------------------------
    # Sort FIRST so every array extracted below shares the same row order.
    # (Extracting some arrays before sorting and others after would silently
    # misalign them whenever the input is not already sorted by subject/start.)
    df = df.sort_values([subject_col, start_col]).reset_index(drop=True)
    # -------------------------------------------------
    # 1. Basic quantities
    # -------------------------------------------------
    df["dt"] = df[stop_col] - df[start_col]
    if (df["dt"] <= 0).any():
        raise ValueError("Non-positive interval lengths detected.")
    # -------------------------------------------------
    # 2. Time-bin indexing (piecewise-constant effects)
    # -------------------------------------------------
    bins = (
        df[[start_col, stop_col]]
        .drop_duplicates()
        .sort_values([start_col, stop_col])
        .reset_index(drop=True)
    )
    bins["bin_idx"] = np.arange(len(bins))
    # A how='left' merge preserves the sorted row order of df.
    df = df.merge(
        bins,
        on=[start_col, stop_col],
        how="left",
        validate="many_to_one"
    )
    # -------------------------------------------------
    # 3. Covariates (mediator centered; treatment kept as-is)
    # -------------------------------------------------
    df["x_c"] = df[x_col]
    df["m_c"] = df[m_col] - df[m_col].mean()
    # -------------------------------------------------
    # 4. Predictable mediator (lag within subject)
    # -------------------------------------------------
    # Lag within subject so the hazard at interval t only sees mediator
    # values from t-1 (a predictable covariate in the counting-process sense).
    df["m_lag"] = (
        df.groupby(subject_col)["m_c"]
        .shift(1)
        .fillna(0.0)
    )
    # Dose-group indicators; rows matching neither level form the reference.
    df["I_low"] = (df[dose_col] == "low").astype(int)
    df["I_high"] = (df[dose_col] == "high").astype(int)
    # -------------------------------------------------
    # 5. Assemble output (all arrays aligned with the sorted df)
    # -------------------------------------------------
    data = {
        "N": df[event_col].astype(int).values,
        "Y": np.ones(len(df), dtype=int),  # Andersen–Gill at-risk indicator
        "dt": df["dt"].values,
        "bin_idx": df["bin_idx"].values,
        "x": df["x_c"].values,
        "m_lag": df["m_lag"].values,
        "n_bins": bins.shape[0],
        "bins": bins,     # useful for plotting
        "df_long": df     # optional: debugging / inspection
    }
    return data


data = prepare_aalen_dpa_data(df)
# Inspect the long-format frame returned by the preparation step.
df_long = data['df_long']
df_long[['subject', 'x', 'dose', 'M', 'event', 'dt', 'bin_idx']].head(14)

|   | subject | x | dose | M | event | dt | bin_idx |
|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | ctrl | 6.74 | 0 | 4.00 | 7 |
| 1 | 1 | 0 | ctrl | 6.91 | 0 | 4.00 | 13 |
| 2 | 1 | 0 | ctrl | 6.90 | 0 | 4.00 | 23 |
| 3 | 1 | 0 | ctrl | 6.71 | 0 | 14.00 | 53 |
| 4 | 1 | 0 | ctrl | 6.45 | 1 | 20.85 | 81 |
| 5 | 2 | 1 | high | 6.11 | 0 | 4.00 | 7 |
| 6 | 2 | 1 | high | 6.28 | 0 | 4.00 | 13 |
| 7 | 2 | 1 | high | 7.04 | 0 | 4.00 | 23 |
| 8 | 2 | 1 | high | 6.93 | 0 | 14.00 | 53 |
| 9 | 2 | 1 | high | 7.86 | 0 | 26.00 | 89 |
| 10 | 2 | 1 | high | 8.47 | 0 | 26.00 | 115 |
| 11 | 2 | 1 | high | 8.91 | 0 | 26.00 | 137 |
| 12 | 2 | 1 | high | 8.99 | 0 | 52.00 | 162 |
| 13 | 2 | 1 | high | 9.36 | 0 | 104.00 | 188 |
# Unpack the prepared arrays so the model code below reads cleanly.
N = data["N"]
Y = data["Y"]
dt = data["dt"]
bin_idx = data["bin_idx"]
x = data["x"]
m_lag = data["m_lag"]
n_bins = data["n_bins"]
bins = data["bins"]
df_long = data["df_long"]
# Width of each time bin.
dt_bins = bins["stop"].values - bins["start"].values
# Centered mediator in long-format row order.
m = df_long["m_c"].values
b = bin_idx

from scipy.interpolate import BSpline
def create_bspline_basis(n_bins, n_knots=10, degree=3):
    """
    Build a clamped B-spline design matrix over the time-bin grid.

    Each column is one basis function evaluated at the integer time points
    0, 1, ..., n_bins - 1; rows therefore index time bins.

    Parameters
    ----------
    n_bins : int
        Number of time bins (evaluation points)
    n_knots : int
        Number of internal knots (fewer = smoother)
    degree : int
        Degree of spline (3 = cubic, recommended)

    Returns
    -------
    basis : np.ndarray
        Matrix of shape (n_bins, n_knots + degree - 1) of basis values
    """
    # Equally spaced internal knots spanning the full time range.
    interior = np.linspace(0, n_bins - 1, n_knots)
    # Clamp the spline by repeating each boundary knot `degree` extra times
    # (degree + 1 copies in total), so the basis sums to one at the edges.
    knot_vector = np.concatenate([
        np.repeat(interior[0], degree),
        interior,
        np.repeat(interior[-1], degree),
    ])
    n_basis = len(knot_vector) - degree - 1
    grid = np.arange(n_bins, dtype=float)
    # Evaluating a spline whose coefficient matrix is the identity yields all
    # basis functions at once: column i of the result is the i-th basis.
    spline = BSpline(knot_vector, np.eye(n_basis), degree, extrapolate=False)
    return spline(grid)
# Build the spline design matrix over the time bins and wrap it in a
# DataFrame for inspection.
n_knots = 10
basis = create_bspline_basis(n_bins, n_knots=n_knots, degree=3)
n_cols = basis.shape[1]
basis_df = pd.DataFrame(basis, columns=[f'feature_{i}' for i in range(n_cols)])
basis_df.head(10)

|   | feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | feature_10 | feature_11 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 0.863149 | 0.133496 | 0.003337 | 0.000018 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 0.739389 | 0.247518 | 0.012946 | 0.000146 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 0.628064 | 0.343219 | 0.028223 | 0.000494 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 0.528515 | 0.421749 | 0.048566 | 0.001170 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 5 | 0.440083 | 0.484261 | 0.073370 | 0.002286 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 6 | 0.362110 | 0.531908 | 0.102032 | 0.003950 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 7 | 0.293939 | 0.565840 | 0.133949 | 0.006272 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 8 | 0.234909 | 0.587211 | 0.168518 | 0.009362 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 9 | 0.184365 | 0.597171 | 0.205134 | 0.013330 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
n_basis = basis.shape[1]
n_obs = data['df_long'].shape[0]
time_bins = data['bins']['bin_idx'].values
observed_mediator = df_long["m_c"].values
observed_events = df_long['event'].astype(int).values
observed_treatment = df_long['x'].astype(int).values
observed_mediator_lag = df_long['m_lag'].values

# Named coordinates for the model's dimensions (ArviZ-friendly labels).
# The spline labels must be built with an f-string: a plain string literal
# inside the comprehension would repeat the SAME label n_basis times and
# yield duplicate coordinate values.
coords = {'tv': ['intercept', 'direct', 'mediator'],
          'splines': [f'spline_f_{i}' for i in range(n_basis)],
          'obs': range(n_obs),
          'time_bins': time_bins}

with pm.Model(coords=coords) as aalen_dpa_model:
    # Observed data containers.
    trt = pm.Data("trt", observed_treatment, dims="obs")
    med = pm.Data("mediator", observed_mediator, dims="obs")
    med_lag = pm.Data("mediator_lag", observed_mediator_lag, dims="obs")
    events = pm.Data("events", observed_events, dims="obs")
    I_low = pm.Data("I_low", df_long["I_low"].values, dims="obs")
    I_high = pm.Data("I_high", df_long["I_high"].values, dims="obs")
    dt = pm.Data("duration", df_long['dt'].values, dims='obs')
    ## because our long data format has a cell per obs
    at_risk = pm.Data("at_risk", np.ones(len(observed_events)), dims="obs")
    basis_ = pm.Data("basis", basis_df.values, dims=('time_bins', 'splines'))
    # -------------------------------------------------
    # 1. B-spline coefficients for HAZARD model
    # -------------------------------------------------
    # Prior on spline coefficients
    # Smaller sigma = less wiggliness
    # Random Walk 1 (RW1) Prior for coefficients
    # This is the Bayesian version of the smoothing penalty in R's 'mgcv' or 'timereg'
    sigma_smooth = pm.Exponential("sigma_smooth", [1, 1, 1], dims='tv')
    beta_raw = pm.Normal("beta_raw", 0, 1, dims=('splines', 'tv'))
    # Cumulative sum makes it a Random Walk
    # This ensures coefficients evolve smoothly over time
    coef_alpha = pm.Deterministic("coef_alpha", pt.cumsum(beta_raw * sigma_smooth, axis=0), dims=('splines', 'tv'))
    # Construct smooth time-varying functions
    alpha_0_t = pt.dot(basis_, coef_alpha[:, 0])  # baseline
    alpha_1_t = pt.dot(basis_, coef_alpha[:, 1])  # direct effect of x
    alpha_2_t = pt.dot(basis_, coef_alpha[:, 2])  # effect of mediator
    # -------------------------------------------------
    # 2. B-spline coefficients for MEDIATOR model
    # -------------------------------------------------
    sigma_beta_smooth = pm.Exponential("sigma_beta_smooth", 0.1)
    # Distinct local name so the hazard-model draws above are not shadowed
    # (the original re-bound `beta_raw` here).
    beta_raw_m = pm.Normal("beta_raw_m", 0, 1, dims=('splines'))
    coef_beta = pt.cumsum(beta_raw_m * sigma_beta_smooth)
    beta_t = pt.dot(basis_, coef_beta)
    # -------------------------------------------------
    # 3. Mediator model (A path: x → M)
    # -------------------------------------------------
    sigma_m = pm.HalfNormal("sigma_m", 1.0)
    # Autoregressive component
    rho = pm.Beta("rho", 2, 2)
    mu_m = beta_t[b] * trt + rho * med_lag
    pm.Normal(
        "obs_m",
        mu=mu_m,
        sigma=sigma_m,
        observed=med,
        dims='obs'
    )
    # -------------------------------------------------
    # 4. Hazard model (direct + B path)
    # -------------------------------------------------
    beta_low = pm.Normal("beta_low", 0, 0.1)
    beta_high = pm.Normal("beta_high", 0, 0.1)
    # Log-additive hazard
    log_lambda_t = (alpha_0_t[b]
                    + alpha_1_t[b] * trt  # direct effect
                    + alpha_2_t[b] * med  # mediator effect
                    + beta_low * I_low
                    + beta_high * I_high
                    )
    # Expected number of events; log1pexp (softplus) keeps the hazard positive.
    time_at_risk = at_risk * dt
    Lambda = time_at_risk * pm.math.log1pexp(log_lambda_t)
    pm.Poisson(
        "obs_event",
        mu=Lambda,
        observed=events,
        dims='obs'
    )
    # -------------------------------------------------
    # 5. Causal path effects
    # -------------------------------------------------
    # Store time-varying coefficients
    pm.Deterministic("alpha_0_t", alpha_0_t, dims='time_bins')
    pm.Deterministic("alpha_1_t", alpha_1_t, dims='time_bins')  # direct effect
    pm.Deterministic("alpha_2_t", alpha_2_t, dims='time_bins')  # B path
    pm.Deterministic("beta_t", beta_t, dims='time_bins')  # A path
    # Cumulative direct effect
    cum_de = pm.Deterministic(
        "tv_direct_effect",
        alpha_1_t,
        dims='time_bins'
    )
    # Cumulative indirect effect (product of paths)
    cum_ie = pm.Deterministic(
        "tv_indirect_effect",
        beta_t * alpha_2_t,
        dims='time_bins'
    )
    # Total effect
    cum_te = pm.Deterministic(
        "tv_total_effect",
        cum_de + cum_ie,
        dims='time_bins'
    )
    # -------------------------------------------------
    # 6. Sample
    # -------------------------------------------------
    trace_additive = pm.sample(
        draws=2000,
        tune=2000,
        target_accept=0.95,
        chains=4,
        nuts_sampler="numpyro",
        random_seed=42,
        init="adapt_diag"
    )
# Run log (moved out of the code line): "There were 12 divergences after
# tuning. Increase `target_accept` or reparameterize."
pm.model_to_graphviz(aalen_dpa_model)

vars_to_plot = ['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect']
labels = ['Time varying Direct Effect', 'Time varying Indirect Effect', 'Time varying Total Effect']
fig, axs = plt.subplots(1, 3, figsize=(20, 6))
# One panel per time-varying effect: posterior mean line plus 94% HDI band.
for i, var in enumerate(vars_to_plot):
    # 1. Extract the posterior samples for this variable
    # Shape will be (chain * draw, time)
    post_samples = az.extract(trace_additive, var_names=[var]).values.T
    # 2. Calculate the mean and the 94% HDI across the chains/draws
    mean_val = post_samples.mean(axis=0)
    # NOTE(review): passing a 2-d numpy array here triggers an ArviZ
    # FutureWarning ("hdi currently interprets 2d data as (draw, shape)");
    # the interpretation may change in a future ArviZ release — confirm
    # before upgrading.
    hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
    # 3. Plot the Mean line
    x_axis = np.arange(len(mean_val))
    axs[i].plot(x_axis, mean_val, label=labels[i], color='teal', lw=2)
    # 4. Plot the Shaded HDI region
    axs[i].fill_between(x_axis, hdi_val[:, 0], hdi_val[:, 1], color='teal', alpha=0.2, label='94% HDI')
    # Formatting
    axs[i].set_title(labels[i])
    axs[i].legend()
    axs[i].grid(alpha=0.3)
# If you have a time vector (e.g., days), replace x_axis with your time values
plt.tight_layout()
plt.show()

/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_28758/127617944.py:13: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_28758/127617944.py:13: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_28758/127617944.py:13: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
Citation
BibTeX citation:
@online{forde,
author = {Forde, Nathaniel},
title = {Aalen’s {Dynamic} {Path} {Model}},
langid = {en}
}
For attribution, please cite this work as:
Forde, Nathaniel. n.d. “Aalen’s Dynamic Path Model.”